{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train GPT on addition\n",
    "\n",
    "Train a GPT model on a dedicated addition dataset to see if a Transformer can learn to add."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set up logging\n",
    "import logging\n",
    "logging.basicConfig(\n",
    "        format=\"%(asctime)s - %(levelname)s - %(name)s -   %(message)s\",\n",
    "        datefmt=\"%m/%d/%Y %H:%M:%S\",\n",
    "        level=logging.INFO,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# make deterministic\n",
    "from mingpt.utils import set_seed\n",
    "set_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn import functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset\n",
    "\n",
    "class AdditionDataset(Dataset):\n",
    "    \"\"\"\n",
    "    Returns addition problems of up to some number of digits in the inputs. Recall\n",
    "    that all GPT cares about are sequences of integers, and completing them according to\n",
    "    patterns in the data. Therefore, we have to somehow encode addition problems\n",
    "    as a sequence of integers.\n",
    "    \n",
    "    The sum of two n-digit numbers gives a third up to (n+1)-digit number. So our\n",
    "    encoding will simply be the n-digit first number, n-digit second number, \n",
    "    and (n+1)-digit result, all simply concatenated together. Because each addition\n",
    "    problem is so structured, there is no need to bother the model with encoding\n",
    "    +, =, or other tokens. Each possible sequence has the same length, and simply\n",
    "    contains the raw digits of the addition problem.\n",
    "    \n",
    "    As a few examples, the 2-digit problems:\n",
    "    - 85 + 50 = 135 becomes the sequence [8, 5, 5, 0, 1, 3, 5]\n",
    "    - 6 + 39 = 45 becomes the sequence [0, 6, 3, 9, 0, 4, 5]\n",
    "    etc.\n",
    "    \n",
    "    We will also only train GPT on the final (n+1)-digits because the first\n",
    "    two n-digits are always assumed to be given. So when we give GPT an exam later,\n",
    "    we will e.g. feed it the sequence [0, 6, 3, 9], which encodes that we'd like\n",
    "    to add 6 + 39, and hope that the model completes the integer sequence with [0, 4, 5]\n",
    "    in 3 sequential steps.\n",
    "    \n",
    "    fun exercise: does it help if the result is asked to be produced in reverse order?\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, ndigit, split):\n",
    "        self.split = split # train/test\n",
    "        self.ndigit = ndigit\n",
    "        self.vocab_size = 10 # 10 possible digits 0..9\n",
    "        # +1 due to potential carry overflow, but then -1 because very last digit doesn't plug back\n",
    "        self.block_size = ndigit + ndigit + ndigit + 1 - 1\n",
    "        \n",
    "        # split up all addition problems into either training data or test data\n",
    "        num = (10**self.ndigit)**2 # total number of possible combinations\n",
    "        r = np.random.RandomState(1337) # make deterministic\n",
    "        perm = r.permutation(num)\n",
    "        num_test = min(int(num*0.2), 1000) # 20% of the whole dataset, or only up to 1000\n",
    "        self.ixes = perm[:num_test] if split == 'test' else perm[num_test:]\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.ixes.size\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        # given a problem index idx, first recover the associated a + b\n",
    "        idx = self.ixes[idx]\n",
    "        nd = 10**self.ndigit\n",
    "        a = idx // nd\n",
    "        b = idx %  nd\n",
    "        c = a + b\n",
    "        render = f'%0{self.ndigit}d%0{self.ndigit}d%0{self.ndigit+1}d' % (a,b,c) # e.g. 03+25=28 becomes \"0325028\" \n",
    "        dix = [int(s) for s in render] # convert each character to its token index\n",
    "        # x will be input to GPT and y will be the associated expected outputs\n",
    "        x = torch.tensor(dix[:-1], dtype=torch.long)\n",
    "        y = torch.tensor(dix[1:], dtype=torch.long) # predict the next token in the sequence\n",
    "        y[:self.ndigit*2-1] = -100 # we will only train in the output locations. -100 will mask loss to zero\n",
    "        return x, y\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create a dataset for e.g. 2-digit addition\n",
    "ndigit = 2\n",
    "train_dataset = AdditionDataset(ndigit=ndigit, split='train')\n",
    "test_dataset = AdditionDataset(ndigit=ndigit, split='test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([4, 7, 1, 7, 0, 6]), tensor([-100, -100, -100,    0,    6,    4]))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataset[0] # sample a training instance just to see what one raw example looks like"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "08/16/2020 23:47:41 - INFO - mingpt.model -   number of parameters: 4.001280e+05\n"
     ]
    }
   ],
   "source": [
    "from mingpt.model import GPT, GPTConfig, GPT1Config\n",
    "\n",
    "# initialize a baby GPT model\n",
    "mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size, \n",
    "                  n_layer=2, n_head=4, n_embd=128)\n",
    "model = GPT(mconf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/18 [00:00<?, ?it/s]/apcv/shared/conda-envs/apcv-6244e1d-566/lib/python3.8/site-packages/torch/nn/parallel/_functions.py:61: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
      "  warnings.warn('Was asked to gather along dimension 0, but all '\n",
      "epoch 1 iter 17: train loss 1.74049. lr 5.994512e-04: 100%|██████████| 18/18 [00:30<00:00,  1.70s/it]\n",
      "08/16/2020 23:48:16 - INFO - mingpt.trainer -   test loss: 1.693525\n",
      "epoch 2 iter 17: train loss 1.50974. lr 5.977197e-04: 100%|██████████| 18/18 [00:01<00:00, 11.61it/s]\n",
      "08/16/2020 23:48:18 - INFO - mingpt.trainer -   test loss: 1.466473\n",
      "epoch 3 iter 17: train loss 1.31133. lr 5.948114e-04: 100%|██████████| 18/18 [00:01<00:00, 11.45it/s]\n",
      "08/16/2020 23:48:20 - INFO - mingpt.trainer -   test loss: 1.256615\n",
      "epoch 4 iter 17: train loss 1.22379. lr 5.907379e-04: 100%|██████████| 18/18 [00:01<00:00, 11.50it/s]\n",
      "08/16/2020 23:48:21 - INFO - mingpt.trainer -   test loss: 1.160792\n",
      "epoch 5 iter 17: train loss 1.14308. lr 5.855153e-04: 100%|██████████| 18/18 [00:01<00:00, 11.63it/s]\n",
      "08/16/2020 23:48:23 - INFO - mingpt.trainer -   test loss: 1.091487\n",
      "epoch 6 iter 17: train loss 1.09970. lr 5.791641e-04: 100%|██████████| 18/18 [00:01<00:00, 11.56it/s]\n",
      "08/16/2020 23:48:25 - INFO - mingpt.trainer -   test loss: 1.050111\n",
      "epoch 7 iter 17: train loss 1.08481. lr 5.717095e-04: 100%|██████████| 18/18 [00:01<00:00, 11.53it/s]\n",
      "08/16/2020 23:48:26 - INFO - mingpt.trainer -   test loss: 1.037456\n",
      "epoch 8 iter 17: train loss 1.03496. lr 5.631810e-04: 100%|██████████| 18/18 [00:01<00:00, 11.59it/s]\n",
      "08/16/2020 23:48:28 - INFO - mingpt.trainer -   test loss: 0.997156\n",
      "epoch 9 iter 17: train loss 0.98606. lr 5.536122e-04: 100%|██████████| 18/18 [00:01<00:00, 11.67it/s]\n",
      "08/16/2020 23:48:30 - INFO - mingpt.trainer -   test loss: 0.836543\n",
      "epoch 10 iter 17: train loss 0.59589. lr 5.430411e-04: 100%|██████████| 18/18 [00:01<00:00, 12.80it/s]\n",
      "08/16/2020 23:48:31 - INFO - mingpt.trainer -   test loss: 0.438013\n",
      "epoch 11 iter 17: train loss 0.50257. lr 5.315093e-04: 100%|██████████| 18/18 [00:01<00:00, 12.99it/s]\n",
      "08/16/2020 23:48:33 - INFO - mingpt.trainer -   test loss: 0.343370\n",
      "epoch 12 iter 17: train loss 0.44096. lr 5.190624e-04: 100%|██████████| 18/18 [00:01<00:00, 12.08it/s]\n",
      "08/16/2020 23:48:34 - INFO - mingpt.trainer -   test loss: 0.277625\n",
      "epoch 13 iter 17: train loss 0.37445. lr 5.057497e-04: 100%|██████████| 18/18 [00:01<00:00, 12.84it/s]\n",
      "08/16/2020 23:48:36 - INFO - mingpt.trainer -   test loss: 0.236511\n",
      "epoch 14 iter 17: train loss 0.31269. lr 4.916238e-04: 100%|██████████| 18/18 [00:01<00:00, 12.90it/s]\n",
      "08/16/2020 23:48:37 - INFO - mingpt.trainer -   test loss: 0.207689\n",
      "epoch 15 iter 17: train loss 0.34095. lr 4.767405e-04: 100%|██████████| 18/18 [00:01<00:00, 12.74it/s]\n",
      "08/16/2020 23:48:39 - INFO - mingpt.trainer -   test loss: 0.165566\n",
      "epoch 16 iter 17: train loss 0.25957. lr 4.611586e-04: 100%|██████████| 18/18 [00:01<00:00, 12.69it/s]\n",
      "08/16/2020 23:48:40 - INFO - mingpt.trainer -   test loss: 0.123080\n",
      "epoch 17 iter 17: train loss 0.23488. lr 4.449397e-04: 100%|██████████| 18/18 [00:01<00:00, 12.85it/s]\n",
      "08/16/2020 23:48:42 - INFO - mingpt.trainer -   test loss: 0.091252\n",
      "epoch 18 iter 17: train loss 0.20269. lr 4.281479e-04: 100%|██████████| 18/18 [00:01<00:00, 12.72it/s]\n",
      "08/16/2020 23:48:43 - INFO - mingpt.trainer -   test loss: 0.078601\n",
      "epoch 19 iter 17: train loss 0.19535. lr 4.108497e-04: 100%|██████████| 18/18 [00:01<00:00, 12.78it/s]\n",
      "08/16/2020 23:48:45 - INFO - mingpt.trainer -   test loss: 0.055412\n",
      "epoch 20 iter 17: train loss 0.16152. lr 3.931133e-04: 100%|██████████| 18/18 [00:01<00:00, 12.66it/s]\n",
      "08/16/2020 23:48:46 - INFO - mingpt.trainer -   test loss: 0.051874\n",
      "epoch 21 iter 17: train loss 0.14061. lr 3.750088e-04: 100%|██████████| 18/18 [00:01<00:00, 12.84it/s]\n",
      "08/16/2020 23:48:48 - INFO - mingpt.trainer -   test loss: 0.044502\n",
      "epoch 22 iter 17: train loss 0.16309. lr 3.566079e-04: 100%|██████████| 18/18 [00:01<00:00, 12.67it/s]\n",
      "08/16/2020 23:48:49 - INFO - mingpt.trainer -   test loss: 0.036376\n",
      "epoch 23 iter 17: train loss 0.14411. lr 3.379832e-04: 100%|██████████| 18/18 [00:01<00:00, 13.21it/s]\n",
      "08/16/2020 23:48:51 - INFO - mingpt.trainer -   test loss: 0.029843\n",
      "epoch 24 iter 17: train loss 0.12110. lr 3.192084e-04: 100%|██████████| 18/18 [00:01<00:00, 12.74it/s]\n",
      "08/16/2020 23:48:52 - INFO - mingpt.trainer -   test loss: 0.025040\n",
      "epoch 25 iter 17: train loss 0.11360. lr 3.003577e-04: 100%|██████████| 18/18 [00:01<00:00, 12.77it/s]\n",
      "08/16/2020 23:48:54 - INFO - mingpt.trainer -   test loss: 0.023500\n",
      "epoch 26 iter 17: train loss 0.13910. lr 2.815056e-04: 100%|██████████| 18/18 [00:01<00:00, 12.78it/s]\n",
      "08/16/2020 23:48:55 - INFO - mingpt.trainer -   test loss: 0.022606\n",
      "epoch 27 iter 17: train loss 0.07931. lr 2.627266e-04: 100%|██████████| 18/18 [00:01<00:00, 12.74it/s]\n",
      "08/16/2020 23:48:57 - INFO - mingpt.trainer -   test loss: 0.015403\n",
      "epoch 28 iter 17: train loss 0.09684. lr 2.440948e-04: 100%|██████████| 18/18 [00:01<00:00, 11.92it/s]\n",
      "08/16/2020 23:48:58 - INFO - mingpt.trainer -   test loss: 0.015245\n",
      "epoch 29 iter 17: train loss 0.09055. lr 2.256841e-04: 100%|██████████| 18/18 [00:01<00:00, 12.77it/s]\n",
      "08/16/2020 23:49:00 - INFO - mingpt.trainer -   test loss: 0.012647\n",
      "epoch 30 iter 17: train loss 0.08837. lr 2.075671e-04: 100%|██████████| 18/18 [00:01<00:00, 12.59it/s]\n",
      "08/16/2020 23:49:01 - INFO - mingpt.trainer -   test loss: 0.011611\n",
      "epoch 31 iter 17: train loss 0.08425. lr 1.898155e-04: 100%|██████████| 18/18 [00:01<00:00, 12.43it/s]\n",
      "08/16/2020 23:49:03 - INFO - mingpt.trainer -   test loss: 0.009952\n",
      "epoch 32 iter 17: train loss 0.10772. lr 1.724993e-04: 100%|██████████| 18/18 [00:01<00:00, 12.40it/s]\n",
      "08/16/2020 23:49:05 - INFO - mingpt.trainer -   test loss: 0.008648\n",
      "epoch 33 iter 17: train loss 0.07272. lr 1.556871e-04: 100%|██████████| 18/18 [00:01<00:00, 12.57it/s]\n",
      "08/16/2020 23:49:06 - INFO - mingpt.trainer -   test loss: 0.010154\n",
      "epoch 34 iter 17: train loss 0.05550. lr 1.394453e-04: 100%|██████████| 18/18 [00:01<00:00, 12.47it/s]\n",
      "08/16/2020 23:49:08 - INFO - mingpt.trainer -   test loss: 0.007668\n",
      "epoch 35 iter 17: train loss 0.05451. lr 1.238381e-04: 100%|██████████| 18/18 [00:01<00:00, 12.59it/s]\n",
      "08/16/2020 23:49:09 - INFO - mingpt.trainer -   test loss: 0.008095\n",
      "epoch 36 iter 17: train loss 0.09133. lr 1.089272e-04: 100%|██████████| 18/18 [00:01<00:00, 12.39it/s]\n",
      "08/16/2020 23:49:11 - INFO - mingpt.trainer -   test loss: 0.006615\n",
      "epoch 37 iter 17: train loss 0.06825. lr 9.477150e-05: 100%|██████████| 18/18 [00:01<00:00, 12.27it/s]\n",
      "08/16/2020 23:49:12 - INFO - mingpt.trainer -   test loss: 0.005874\n",
      "epoch 38 iter 17: train loss 0.05798. lr 8.142699e-05: 100%|██████████| 18/18 [00:01<00:00, 12.49it/s]\n",
      "08/16/2020 23:49:14 - INFO - mingpt.trainer -   test loss: 0.005701\n",
      "epoch 39 iter 17: train loss 0.06975. lr 6.894639e-05: 100%|██████████| 18/18 [00:01<00:00, 12.88it/s]\n",
      "08/16/2020 23:49:15 - INFO - mingpt.trainer -   test loss: 0.005469\n",
      "epoch 40 iter 17: train loss 0.06070. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.80it/s]\n",
      "08/16/2020 23:49:17 - INFO - mingpt.trainer -   test loss: 0.005307\n",
      "epoch 41 iter 17: train loss 0.06378. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.60it/s]\n",
      "08/16/2020 23:49:18 - INFO - mingpt.trainer -   test loss: 0.005681\n",
      "epoch 42 iter 17: train loss 0.04885. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.81it/s]\n",
      "08/16/2020 23:49:20 - INFO - mingpt.trainer -   test loss: 0.005456\n",
      "epoch 43 iter 17: train loss 0.06409. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.81it/s]\n",
      "08/16/2020 23:49:21 - INFO - mingpt.trainer -   test loss: 0.004907\n",
      "epoch 44 iter 17: train loss 0.07563. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.69it/s]\n",
      "08/16/2020 23:49:23 - INFO - mingpt.trainer -   test loss: 0.004650\n",
      "epoch 45 iter 17: train loss 0.03149. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.79it/s]\n",
      "08/16/2020 23:49:24 - INFO - mingpt.trainer -   test loss: 0.004626\n",
      "epoch 46 iter 17: train loss 0.07037. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.86it/s]\n",
      "08/16/2020 23:49:26 - INFO - mingpt.trainer -   test loss: 0.004147\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 47 iter 17: train loss 0.07650. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.82it/s]\n",
      "08/16/2020 23:49:27 - INFO - mingpt.trainer -   test loss: 0.004611\n",
      "epoch 48 iter 17: train loss 0.06342. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.63it/s]\n",
      "08/16/2020 23:49:29 - INFO - mingpt.trainer -   test loss: 0.004083\n",
      "epoch 49 iter 17: train loss 0.12429. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.69it/s]\n",
      "08/16/2020 23:49:30 - INFO - mingpt.trainer -   test loss: 0.004081\n",
      "epoch 50 iter 17: train loss 0.04616. lr 6.000000e-05: 100%|██████████| 18/18 [00:01<00:00, 12.19it/s]\n",
      "08/16/2020 23:49:32 - INFO - mingpt.trainer -   test loss: 0.003922\n"
     ]
    }
   ],
   "source": [
    "from mingpt.trainer import Trainer, TrainerConfig\n",
    "\n",
    "# initialize a trainer instance and kick off training\n",
    "tconf = TrainerConfig(max_epochs=50, batch_size=512, learning_rate=6e-4,\n",
    "                      lr_decay=True, warmup_tokens=1024, final_tokens=50*len(train_dataset)*(ndigit+1),\n",
    "                      num_workers=4)\n",
    "trainer = Trainer(model, train_dataset, test_dataset, tconf)\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# now let's give the trained model an addition exam\n",
    "from torch.utils.data.dataloader import DataLoader\n",
    "from mingpt.utils import sample\n",
    "\n",
    "def give_exam(dataset, batch_size=32, max_batches=-1):\n",
    "    \n",
    "    results = []\n",
    "    loader = DataLoader(dataset, batch_size=batch_size)\n",
    "    for b, (x, y) in enumerate(loader):\n",
    "        x = x.to(trainer.device)\n",
    "        d1d2 = x[:, :ndigit*2]\n",
    "        d1d2d3 = sample(model, d1d2, ndigit+1)\n",
    "        d3 = d1d2d3[:, -(ndigit+1):]\n",
    "        factors = torch.tensor([[10**i for i in range(ndigit+1)][::-1]]).to(trainer.device)\n",
    "        # decode the integers from individual digits\n",
    "        d1i = (d1d2[:,:ndigit] * factors[:,1:]).sum(1)\n",
    "        d2i = (d1d2[:,ndigit:ndigit*2] * factors[:,1:]).sum(1)\n",
    "        d3i_pred = (d3 * factors).sum(1)\n",
    "        d3i_gt = d1i + d2i\n",
    "        correct = (d3i_pred == d3i_gt).cpu() # Software 1.0 vs. Software 2.0 fight RIGHT on this line, lol\n",
    "        for i in range(x.size(0)):\n",
    "            results.append(int(correct[i]))\n",
    "            judge = 'YEP!!!' if correct[i] else 'NOPE'\n",
    "            if not correct[i]:\n",
    "                print(\"GPT claims that %03d + %03d = %03d (gt is %03d; %s)\" \n",
    "                      % (d1i[i], d2i[i], d3i_pred[i], d3i_gt[i], judge))\n",
    "        \n",
    "        if max_batches >= 0 and b+1 >= max_batches:\n",
    "            break\n",
    "\n",
    "    print(\"final score: %d/%d = %.2f%% correct\" % (np.sum(results), len(results), 100*np.mean(results)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "final score: 9000/9000 = 100.00% correct\n"
     ]
    }
   ],
   "source": [
    "# training set: how well did we memorize?\n",
    "give_exam(train_dataset, batch_size=1024, max_batches=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GPT claims that 055 + 045 = 090 (gt is 100; NOPE)\n",
      "final score: 999/1000 = 99.90% correct\n"
     ]
    }
   ],
   "source": [
    "# test set: how well did we generalize?\n",
    "give_exam(test_dataset, batch_size=1024, max_batches=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# well that's amusing... our model learned everything except 55 + 45"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
